{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image Classification with DJL Spark Support\n",
    "\n",
    "In this example, we use a Jupyter notebook to run image classification with the DJL Spark extension in Scala. To run this notebook, you need to install [Almond](https://almond.sh/), a Scala kernel for Jupyter. Almond provides extensive support for Scala and Spark applications.\n",
    "\n",
    "[Almond installation instructions](https://almond.sh/docs/quick-start-install) (Note: only Scala 2.12 has been tested)\n",
    "\n",
    "After that, you can start with DJL's Scala notebook.\n",
    "\n",
    "\n",
    "## Import dependencies\n",
    "\n",
    "First, let's import the dependencies we need. We use DJL PyTorch as our backend engine. You can also switch to MXNet by replacing the PyTorch dependencies with the corresponding MXNet ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import $ivy.`org.apache.spark::spark-sql:3.3.2`\n",
    "import $ivy.`ai.djl:api:0.27.0`\n",
    "import $ivy.`ai.djl.spark:spark_2.12:0.27.0`\n",
    "import $ivy.`ai.djl.pytorch:pytorch-model-zoo:0.27.0`\n",
    "import $ivy.`ai.djl.pytorch:pytorch-native-cpu-precxx11:2.1.1`\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can import the packages we need. The last two lines turn off logging for the `org` and `ai` logger namespaces (which include Spark and DJL) to avoid cluttering your cell outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.apache.spark.sql.NotebookSparkSession\n",
    "import ai.djl.spark.task.vision.ImageClassifier\n",
    "import org.apache.spark.sql.SparkSession\n",
    "\n",
    "import org.apache.log4j.{Level, Logger}\n",
    "Logger.getLogger(\"org\").setLevel(Level.OFF) // silence log output from org.* loggers (includes Spark)\n",
    "Logger.getLogger(\"ai\").setLevel(Level.OFF) // silence log output from ai.* loggers (includes DJL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start Spark application\n",
    "\n",
    "We can create a `NotebookSparkSession` through the Almond Spark plugin. Internally, it distributes all the necessary jars to each worker node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Create Spark session\n",
    "val spark = {\n",
    "  NotebookSparkSession.builder()\n",
    "    .master(\"local[*]\")\n",
    "    .getOrCreate()\n",
    "}\n",
    "\n",
    "// Use PyTorch precxx11\n",
    "System.setProperty(\"PYTORCH_PRECXX11\", \"true\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's load the images from a local folder using Spark's `image` data source:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var df = spark.read.format(\"image\").option(\"dropInvalid\", true).load(\"../image-classification/images\")\n",
    "df.printSchema()"
   ]
  },
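  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check before running inference, the sketch below lists the metadata of each loaded image. It assumes the `df` created in the previous cell and relies on the `image` struct column produced by Spark's image data source; the choice of columns to display is only illustrative.\n",
    "\n",
    "```scala\n",
    "// Show where each image came from, its dimensions and its channel count.\n",
    "// Assumes `df` is the DataFrame loaded in the previous cell.\n",
    "df.select(\"image.origin\", \"image.height\", \"image.width\", \"image.nChannels\")\n",
    "  .show(truncate = false)\n",
    "```\n",
    "\n",
    "Images whose `nChannels` value is not 3 are filtered out in the next step, because the model expects RGB input."
   ]
  },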
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can run inference on these images. All we need to do is create an `ImageClassifier` and run inference with DJL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.select(\"image.*\").filter(\"nChannels=3\") // The model expects RGB images\n",
    "val classifier = new ImageClassifier()\n",
    "  .setInputCols(Array(\"origin\", \"height\", \"width\", \"nChannels\", \"mode\", \"data\"))\n",
    "  .setOutputCol(\"prediction\")\n",
    "  .setEngine(\"PyTorch\")\n",
    "  .setModelUrl(\"djl://ai.djl.pytorch/resnet\")\n",
    "  .setTopK(2)\n",
    "val outputDf = classifier.classify(df)\n",
    "outputDf.select(\"origin\", \"prediction.top_k\").show(truncate=false)"
   ]
  }
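  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above only displays the `top_k` field of the nested `prediction` column. To see which other fields the classifier adds, one option is to print the schema of the result. This is a minimal sketch that assumes the `outputDf` from the previous cell; the exact fields may depend on the DJL version, so no output is shown here.\n",
    "\n",
    "```scala\n",
    "// Print the full schema of the classifier output, including the nested `prediction` column.\n",
    "// Assumes `outputDf` is the DataFrame returned by `classifier.classify(df)`.\n",
    "outputDf.printSchema()\n",
    "\n",
    "// Display the top-k predictions for a single row only.\n",
    "outputDf.select(\"prediction.top_k\").show(1, truncate = false)\n",
    "```"
   ]
  }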
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Scala",
   "language": "scala",
   "name": "scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".sc",
   "mimetype": "text/x-scala",
   "name": "scala",
   "nbconvert_exporter": "script",
   "version": "2.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
